package com.xzixi.websocket.interceptablewebsocketdemo.websocket;

import com.xzixi.websocket.interceptablewebsocket.interceptor.FromClientInterceptor;
import com.xzixi.websocket.interceptablewebsocket.util.MessageFromClient;
import com.xzixi.websocket.interceptablewebsocket.util.MessageMatcher;
import com.xzixi.websocket.interceptablewebsocket.util.NotifyUtil;
import com.xzixi.websocket.interceptablewebsocketdemo.entity.Resource;
import com.xzixi.websocket.interceptablewebsocketdemo.entity.User;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;

import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Authorization decision for inbound WebSocket messages.
 *
 * <p>Messages whose type is in {@link #UNRESTRICTED_TYPES} are always allowed. Every other message
 * requires a {@link User} stored in the session attributes under the key {@code "user"}, and that
 * user must own at least one {@link Resource} with protocol {@code "ws"} that matches the message
 * according to {@link MessageMatcher}. Otherwise the message is rejected.
 *
 * @author 薛凌康
 */
public class AccessDecisionFromClientInterceptor implements FromClientInterceptor {

    /**
     * Message types that bypass authorization entirely.
     *
     * <p>The set is backed by a {@link HashSet}, which accepts a {@code null} lookup key, so a
     * message with no type simply counts as "not exempt" rather than throwing.
     */
    private static final Set<String> UNRESTRICTED_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
            "CONNECT", "CONNECT_ACK", "HEARTBEAT", "UNSUBSCRIBE", "DISCONNECT", "DISCONNECT_ACK", "OTHER")));

    @Autowired
    private MessageMatcher messageMatcher;
    @Autowired
    private NotifyUtil notifyUtil;

    /**
     * Decides whether an inbound client message may proceed.
     *
     * @param session       the WebSocket session the message arrived on; its {@code "user"}
     *                      attribute is expected to hold the logged-in {@link User}, if any
     * @param principal     not used by this interceptor
     * @param message       the inbound client message
     * @param outputChannel not used by this interceptor
     * @param handler       not used by this interceptor
     * @return {@code true} to let the message through, {@code false} to drop it
     */
    @Override
    public boolean preHandle(WebSocketSession session, Principal principal, MessageFromClient message,
                             MessageChannel outputChannel, StompSubProtocolHandler handler) {
        // Messages of these types are never intercepted.
        if (UNRESTRICTED_TYPES.contains(message.getType())) {
            return true;
        }
        // Reject messages from sessions without a logged-in user (no notification is sent).
        User user = (User) session.getAttributes().get("user");
        if (user == null) {
            return false;
        }
        // Check the user's permissions.
        if (checkResources(message, user.getResources())) {
            return true;
        }
        // Notify the user. If an object is passed instead of a String, the client receives JSON.
        notifyUtil.sendMessage(session, "没有权限");
        return false;
    }

    /**
     * Checks whether any of the user's resources grants access to {@code message}.
     *
     * @param message   the inbound client message
     * @param resources the user's granted resources; {@code null} is treated as an empty list
     * @return {@code true} if some resource with protocol {@code "ws"} is matched by
     *         {@code messageMatcher.matches(pattern, type, message)}
     */
    private boolean checkResources(MessageFromClient message, List<Resource> resources) {
        // A user whose resources were never populated has no permissions; this avoids a
        // NullPointerException in the interceptor chain.
        if (resources == null) {
            return false;
        }
        for (Resource resource : resources) {
            if ("ws".equals(resource.getProtocol())
                    && messageMatcher.matches(resource.getPattern(), resource.getType(), message)) {
                return true;
            }
        }
        return false;
    }

}
